#include <stdio.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

#include <cryptopp/aes.h>
#include <cryptopp/filters.h>
#include <cryptopp/modes.h>

#include "AES.h"

using namespace std;

byte key[ CryptoPP::AES::DEFAULT_KEYLENGTH ], iv[ CryptoPP::AES::BLOCKSIZE];

// Resets the shared AES key and CBC IV to all-zero bytes.
//
// NOTE(review): an all-zero key and IV are fixed, publicly guessable values,
// and reusing one IV makes CBC leak equal-prefix information between messages.
// This is only acceptable for a demo; real use needs a secret key and a fresh
// random IV per message.
void MyAES::initKV()
{
    // Size each memset from the array itself so it cannot drift from the
    // declaration (the element type is one byte, so this equals the
    // DEFAULT_KEYLENGTH / BLOCKSIZE constants).
    memset(iv, 0x00, sizeof iv);
    memset(key, 0x00, sizeof key);
}

// Encodes raw bytes as lowercase hex, two digits per byte (the same text that
// printf("%02x") would produce for each byte).
static string encodeHexLower(const string& bytes)
{
    static const char kDigits[] = "0123456789abcdef";

    string hexText;
    hexText.reserve(bytes.size() * 2);
    for (string::size_type i = 0; i < bytes.size(); ++i)
    {
        // Go through unsigned char so bytes >= 0x80 are not sign-extended.
        const unsigned char b = static_cast<unsigned char>(bytes[i]);
        hexText += kDigits[b >> 4];
        hexText += kDigits[b & 0x0F];
    }
    return hexText;
}

// Encrypts plainText with AES in CBC mode using the shared `key`/`iv`, and
// returns the ciphertext as a lowercase hex string.
//
// The terminating NUL of plainText.c_str() is encrypted too (length() + 1), so
// the ciphertext always decrypts to plainText followed by one '\0'. This is kept
// so previously produced ciphertexts stay byte-for-byte compatible.
//
// Exceptions thrown by Crypto++ propagate to the caller.
string MyAES::encrypt(string plainText)
{
    string cipherText;

    CryptoPP::AES::Encryption aesEncryption(key, CryptoPP::AES::DEFAULT_KEYLENGTH);
    CryptoPP::CBC_Mode_ExternalCipher::Encryption cbcEncryption(aesEncryption, iv);
    // The filter takes ownership of the heap-allocated StringSink.
    CryptoPP::StreamTransformationFilter stfEncryptor(cbcEncryption, new CryptoPP::StringSink(cipherText));
    stfEncryptor.Put(reinterpret_cast<const unsigned char*>(plainText.c_str()), plainText.length() + 1);
    stfEncryptor.MessageEnd();

    return encodeHexLower(cipherText);
}

// Returns the value (0-15) of one hex digit in either case, or -1 if c is not
// a hex digit.
static int hexDigitValue(char c)
{
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return c - 'a' + 10;
    if (c >= 'A' && c <= 'F') return c - 'A' + 10;
    return -1;
}

// Decodes hex text (two digits per byte) into raw bytes.
// Throws std::invalid_argument if the input is empty, has odd length, or
// contains a non-hex character. The old per-byte stringstream parse crashed
// on empty input via an unsigned underflow and accepted garbage silently.
static string decodeHexBytes(const string& hexText)
{
    if (hexText.empty())
        throw std::invalid_argument("MyAES::decrypt: hex input is empty");
    if (hexText.size() % 2 != 0)
        throw std::invalid_argument("MyAES::decrypt: hex input has odd length");

    string bytes;
    bytes.reserve(hexText.size() / 2);
    for (string::size_type i = 0; i < hexText.size(); i += 2)
    {
        const int high = hexDigitValue(hexText[i]);
        const int low = hexDigitValue(hexText[i + 1]);
        if (high < 0 || low < 0)
            throw std::invalid_argument("MyAES::decrypt: hex input contains a non-hex character");
        bytes += static_cast<char>((high << 4) | low);
    }
    return bytes;
}

// Decrypts a hex-encoded AES-CBC ciphertext (as produced by MyAES::encrypt)
// using the shared `key`/`iv` and returns the recovered text.
//
// encrypt() also encrypts the trailing NUL of c_str(). A single trailing '\0'
// is therefore removed here, so decrypt(encrypt(s)) == s.
//
// Throws std::invalid_argument for malformed hex. Errors raised by Crypto++
// (e.g. for a wrong key or corrupted data) propagate to the caller.
string MyAES::decrypt(string cipherTextHex)
{
    const string cipherText = decodeHexBytes(cipherTextHex);
    string decryptedText;

    CryptoPP::AES::Decryption aesDecryption(key, CryptoPP::AES::DEFAULT_KEYLENGTH);
    CryptoPP::CBC_Mode_ExternalCipher::Decryption cbcDecryption(aesDecryption, iv);
    // The filter takes ownership of the heap-allocated StringSink.
    CryptoPP::StreamTransformationFilter stfDecryptor(cbcDecryption, new CryptoPP::StringSink(decryptedText));
    stfDecryptor.Put(reinterpret_cast<const unsigned char*>(cipherText.data()), cipherText.size());
    stfDecryptor.MessageEnd();

    if (!decryptedText.empty() && decryptedText[decryptedText.size() - 1] == '\0')
        decryptedText.erase(decryptedText.size() - 1);

    return decryptedText;
}

// Demo: reads one line from stdin, encrypts it, prints the hex ciphertext,
// then decrypts it again and prints the result.
int main()
{
    // Held by value and value-initialized like the original `new MyAES()`,
    // which was never deleted.
    MyAES aes = MyAES();

    cout << "Give an input: ";
    // Read one whole line; '\n' terminates the input. std::getline grows the
    // string as needed, so long input can no longer overflow a fixed
    // char[100]. An empty line or EOF leaves `text` empty instead of
    // uninitialized.
    string text;
    getline(cin, text);
    cout << "text : " << text << endl;

    try
    {
        // encrypt
        aes.initKV();
        string cipherHex = aes.encrypt(text);
        cout << "cipher : " << cipherHex << endl;

        // decrypt (the value printed under the "ciphertext" label is the
        // recovered plaintext)
        string decrypted = aes.decrypt(cipherHex);
        cout << "ciphertext : " << decrypted << endl;
    }
    catch (const std::exception& e)
    {
        // Crypto++ exceptions derive from std::exception, as does the
        // std::invalid_argument thrown for malformed hex.
        cerr << "error: " << e.what() << endl;
        return 1;
    }

    return 0;
}